# PyTorch DDP benchmarking on torchvision models.
#
# Inspired by https://github.com/pytorch/pytorch/tree/main/benchmarks/distributed/ddp
# 
# The output of the script shows the throughput of a set of
# torchvision models for a given batch size and number of GPUs.
# This is useful for end-to-end benchmarking of multi-node
# performance on a given cloud. Example output from this task
# for a ResNet model is shown below:
#
# Benchmark: resnet101 with batch size 32
# 
#                             sec/iter    ex/sec      sec/iter    ex/sec      sec/iter    ex/sec      sec/iter    ex/sec
#    1 GPUs --   no ddp:  p50:  0.064s     501/s  p75:  0.064s     499/s  p90:  0.064s     497/s  p95:  0.065s     491/s
#    1 GPUs --    1M/1G:  p50:  0.064s     502/s  p75:  0.064s     502/s  p90:  0.064s     502/s  p95:  0.064s     501/s
#    2 GPUs --    1M/2G:  p50:  0.066s     486/s  p75:  0.066s     486/s  p90:  0.066s     484/s  p95:  0.066s     482/s
#    4 GPUs --    1M/4G:  p50:  0.068s     468/s  p75:  0.069s     464/s  p90:  0.070s     457/s  p95:  0.077s     417/s
#    8 GPUs --    1M/8G:  p50:  0.069s     465/s  p75:  0.069s     464/s  p90:  0.069s     463/s  p95:  0.069s     463/s
#   16 GPUs --    2M/8G:  p50:  0.089s     359/s  p75:  0.090s     356/s  p90:  0.091s     350/s  p95:  0.094s     340/s
#
# Usage:
#   sky launch -c torch_bench examples/torch_ddp_benchmark/torch_ddp_benchmark.yaml
#
#   # Terminate cluster after you're done
#   sky down torch_bench
# Name of this task.
name: torch-ddp-bench

# Number of nodes the task is launched on; the run command starts 8 worker
# processes on each of them.
num_nodes: 2

# Per-node hardware request.
resources:
  # Make sure you use 8-GPU instances: the run command starts 8 worker
  # processes per node (--nproc_per_node 8).
  accelerators: A100:8
  cloud: gcp
  use_spot: true

# Copies the benchmark script to each node. The key is the destination path on
# the node and the value is the local source path, which is relative to the
# directory `sky launch` is invoked from (see Usage above).
file_mounts:
  ./torch_ddp_benchmark.py: ./examples/torch_ddp_benchmark/torch_ddp_benchmark.py

setup: |
  # Reuse the 'ddp' conda environment when activating it succeeds; otherwise
  # create it with Python 3.10 and switch into it.
  if conda activate ddp; then
    echo 'conda env exists'
  else
    conda create -n ddp python=3.10 -y
    conda activate ddp
  fi
  # Runs on both paths, so the benchmark's dependencies are installed either way.
  pip install torch torchvision

run: |
  conda activate ddp
  # SKYPILOT_NODE_IPS is read as one IP per line: the line count gives the
  # node count and the first entry is used as the rendezvous host.
  node_count=$(echo "$SKYPILOT_NODE_IPS" | wc -l)
  head_ip=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  # torchrun is started with LD_LIBRARY_PATH emptied for this command only.
  LD_LIBRARY_PATH="" torchrun \
    --nnodes $node_count \
    --nproc_per_node 8 \
    --rdzv_backend=c10d \
    --rdzv_id=1 \
    --rdzv_endpoint=${head_ip}:1234 \
    torch_ddp_benchmark.py --distributed-backend nccl
